Add GPU implementation of NMSv2 op #28745
Conversation
…n of CombinedNonMaxSuppression op for Funcdef executions in TFTRT fallback path
@tfboyd This is the first of a series of PRs that will improve performance on object detection networks.
The test for the new op is blocked by #28744, since GPU tensors are not correctly transferred to the host without it.
… header doesn't declare it
Hi @chsigg, could you please take a look at this PR?
```cpp
void NMSKernel(const Box* d_desc_sorted_boxes, const int nboxes,
               const float thresh, const int mask_ld, int* d_delete_mask,
               bool flip_boxes = false) {
  // Storing boxes used by this CUDA block in the shared memory
```
Comments should end with a '.'.
```cpp
  // One 1D line load the boxes for x-dimension
  if (threadIdx.y == 0) {
    const Box box = d_desc_sorted_boxes[i_to_load];
    Box flipped = box;
```
I would do this on `box` directly with swap. It's unexpected to call this `flipped` when it's only flipped if `flip_boxes` is true.
```cpp
__launch_bounds__(NMS_BLOCK_DIM* NMS_BLOCK_DIM, 4) __global__
    void NMSKernel(const Box* d_desc_sorted_boxes, const int nboxes,
                   const float thresh, const int mask_ld, int* d_delete_mask,
                   bool flip_boxes = false) {
```
Prefer no default arguments.
Would it help performance to make this a template parameter?
```cpp
    }
  }
  __syncthreads();
  const int i = i_block_offset + threadIdx.x;
```
This is the same as i_to_load, no?
```cpp
  // both take about the same time
  int nto_copy = std::min(NMS_CHUNK_SIZE, N);
  cudaEvent_t copy_done;
  cudaEventCreate(&copy_done);
```
Use stream_executor::Event.
@csigg I wasn't able to find any examples of using stream_executor::Event in a similar fashion elsewhere in the kernels. Even so, I don't think this logic can be implemented with stream_executor::Event, since the framework has no mechanism equivalent to cudaEventSynchronize(). I could spin on event::poll, but that would be quite inefficient and would probably hinder the rest of the framework as well because of the locks it acquires. I would have preferred to chain these with ThenExecute(), but that would require converting all NMS ops to AsyncOps, as well as a proper thread pool on the event manager. Currently all events are executed on a single thread, and doing work there would block the event infrastructure. I could spawn the work on the CPU device thread pool from the event callback, but I am not sure that level of complexity is justified.
How would you propose I use stream_executor::Event? It is possible that I am missing something obvious.
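For context, the CUDA-runtime pattern being discussed (record an event after an async copy, then block the host on it) looks roughly like this; variable names are illustrative, not the PR's actual code:

```cuda
// Sketch of the cudaEvent-based host synchronization discussed above.
// Names are illustrative; error checking omitted for brevity.
cudaEvent_t copy_done;
cudaEventCreate(&copy_done);

// Enqueue an async device-to-host copy of one chunk on `stream`,
// then record the event behind it on the same stream.
cudaMemcpyAsync(h_chunk, d_chunk, chunk_bytes, cudaMemcpyDeviceToHost, stream);
cudaEventRecord(copy_done, stream);

// Block the host thread until the copy has finished. This blocking wait
// is the piece with no direct stream_executor::Event equivalent.
cudaEventSynchronize(copy_done);
cudaEventDestroy(copy_done);
```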
```cpp
  explicit NonMaxSuppressionV2GPUOp(OpKernelConstruction* context)
      : OpKernel(context) {}

  void Compute(OpKernelContext* context) override {
```
Add a comment explaining what the implementation does. Inline (like above the sections below) is also fine.
Sorry for the delay; there are some internal test failures and I'm still trying to fix them.
PiperOrigin-RevId: 252461000
I believe this implementation is wrong: it does not agree with the CPU version of the NMS op. When computing the area in IOU, this implementation uses tensorflow/tensorflow/core/kernels/non_max_suppression_op.cu.cc Lines 96 to 98 in e3062d1
However, the CPU version uses tensorflow/tensorflow/core/kernels/non_max_suppression_op.cc Lines 126 to 129 in e3062d1
For many inputs this may have no effect at all, but for certain inputs the two versions will produce inconsistent results. If your goal is to run object detection models, note that the "+1" is a legacy issue, and we're trying to avoid the "+1" version at Facebook. See this PR that handles "+1" in caffe2.
@samikama I tried to use your kernel from Python and I am getting a segmentation fault when running this simple script:

```python
import tensorflow as tf
tf.enable_eager_execution()
from tensorflow.python.ops import gen_image_ops

with tf.device("/device:GPU:0"):
    boxes = tf.constant([[1.0, 1.0, 1.0, 1.0]], dtype=tf.float32)
    scores = tf.constant([1.0], dtype=tf.float32)
    max_output_size = tf.constant(10, dtype=tf.int32)
    iou_threshold = tf.constant(0.7, dtype=tf.float32)
    score_threshold = tf.constant(float('-inf'), dtype=tf.float32)
    print("Start")
    x = gen_image_ops.non_max_suppression_v2(boxes, scores, max_output_size,
                                             iou_threshold, score_threshold)
    print("End")
    print(x)
```

invoked as:

```shell
docker run --runtime=nvidia -it -v $PWD:/tf -w /tf tensorflow/tensorflow:nightly-gpu-py3 \
    python pyscript.py
```

The output is a segmentation fault. Correct me if I am doing something wrong.
@ppwwyyxx Thanks for catching that. I made fixes in #30893 to support both the legacy case and an implementation identical to the CPU one.
Another point: you are passing a box with zero surface area, and it is the only box. Even though there is a single-box test in the test suite, we didn't have an invalid-box test. I will add a fix for it in an upcoming PR.
@samikama with .HostMemory() it works, thank you so much! 👍
Has this made it into any of the TensorFlow releases? As far as I know, it wasn't included in 1.13 or 1.14. How about TensorFlow 2.0?
This PR adds a GPU implementation of the NMSv2 op. It also registers a FakeGPU op for the CombinedNonMaxSuppression op to work around issues caused by the lack of a GPU implementation, until a proper GPU implementation can be built on the current GPU kernels.